Skip to content

Create shared Megatron calibration forward loop for prune / quantize#1501

Open
kevalmorabia97 wants to merge 4 commits into
mainfrom
kmorabia/fix-mcore-minitron-hybrid-and-qwen-te-spec
Open

Create shared Megatron calibration forward loop for prune / quantize#1501
kevalmorabia97 wants to merge 4 commits into
mainfrom
kmorabia/fix-mcore-minitron-hybrid-and-qwen-te-spec

Conversation

@kevalmorabia97
Copy link
Copy Markdown
Collaborator

@kevalmorabia97 kevalmorabia97 commented May 15, 2026

Summary

Replaces the bespoke calibration loops in Megatron-LM and Megatron-Bridge prune / quantize example scripts with a single shared utility, modelopt.torch.utils.plugins.megatron_calibration.get_megatron_calibration_forward_loop.

The shared loop:

  • Calls get_dataset_dataloader (one sample per row, batch-padded) — single source of truth for the calibration dataset surface.
  • Trims each row to its real length using the dataloader's attention_mask, then forces EOS at the trimmed last position, matching MBridge's GPTSFTDataset(add_eos=True) semantics exactly. Padding-token activations would otherwise be hooked into calibration statistics regardless of attention masking (FFN/layernorm fire on every position), causing a substantial MMLU regression on prune (-5 to -7 pts in our experiments).
  • Sorts samples by real length descending so front batches are mostly full-length (has_padding=False → true batched forward, mbs > 1 throughput); back batches that contain padding fall through to per-row forward to keep calibration stats clean.
  • Forwards via megatron_prefill(skip_return_logits=True) (no logits compute, just activation flow for hooks).

Migrates four call sites to the shared util:

  • examples/megatron_bridge/prune_minitron.py
  • Megatron-LM/examples/post_training/modelopt/{prune,quantize}.py (separate PR: NVIDIA/Megatron-LM#4881)
  • Megatron-Bridge/examples/quantization/quantize.py (separate PR)

Unified defaults across all four sites: --calib-dataset nemotron-post-training-dataset-v2, --calib-size 1024, --calib-max-sequence-length 4096, --calib-batch-size 1. Conservative defaults sized for MoE pruning (top-K routing → fewer tokens per expert → more samples × longer seq needed for stable amax / scoring).

Experimental results

Validated on Qwen3-8B (TP=1 PP=2 for prune; TP=2 PP=1 for MMLU eval) with the production shared loop vs the original per-example bespoke loops.

MMLU noise floor (binomial 2σ at acc ≈ 0.70):

  • 5% MMLU (n=728): ±3.4 pt — any two numbers within this range are statistically indistinguishable.
  • Full MMLU (n=14042): ±0.78 pt — much tighter; what we use for headline claims.

M-LM Minitron prune (Qwen3-8B → 30L/3584/11776 ≈ 6B params)

Calibration Dataset seq_len calib_size calib_bs 5% MMLU
Original (inline pack=True WAR) cnn_dailymail 512 512 1 / 16 0.530
Shared loop cnn_dailymail 512 512 16 0.555 (+2.5)
Shared loop nemotron-v2 2048 512 16 0.585
Shared loop nemotron-v2 2048 1024 16 0.588
Shared loop nemotron-v2 4096 512 16 0.587

The +2.5 pt 5% gain on cnn_dailymail is at the edge of 5% MMLU noise (±3.4 pt); full MMLU verification is pending. The +5.5 pt gain switching to nemotron-v2 at seq=2048/4096 is well above noise — the longer SFT-style calibration meaningfully improves prune quality.

M-LM NVFP4 quantize (NVFP4_DEFAULT_CFG)

Calibration Dataset seq_len calib_size calib_bs 5% MMLU Full MMLU Δ vs Original (full)
Original (get_calib_dataloader pad+truncate) cnn_dailymail 512 512 1 0.680 0.703
Shared loop cnn_dailymail 512 512 1 0.694 0.707 +0.46 pt
Shared loop cnn_dailymail 512 512 16 0.698 0.710 +0.77 pt
Shared loop nemotron-v2 4096 512 8 0.710 0.708 +0.51 pt

For reference, hf_ptq.py on the same nemotron-v2 / seq=4096 / calib=512 setup reaches 5% MMLU 0.707 / Full 0.712 at bs=1, confirming the M-LM and HF calibration paths agree within MMLU noise for Qwen3-8B.

M-Bridge Minitron prune

Calibration Dataset seq_len calib_size calib_bs 5% MMLU
Original (get_hf_mbridge_calibration_loop, M-Bridge SFT pipeline) cnn_dailymail 512 512 1 0.549
Shared loop cnn_dailymail 512 512 1 0.559 (+1.0)

5% only; +1.0 pt is within 5% MMLU noise. The headline result here is that the shared loop does not regress M-Bridge prune (earlier non-production variants of the shared loop did — see commit history; the trim+EOS production fix closed it). Full MMLU verification pending.

Conclusions

  • Shared loop is ≥ original on every workload tested at 5% MMLU; on full MMLU the M-LM quantize gain shrinks into the ±0.78 pt noise band.
  • Larger calib_size (512 → 1024) and longer seq (2048 → 4096) within noise for dense Qwen3-8B but kept as conservative defaults for MoE robustness.
  • Sort-by-length is order-invariant for max-calibrator amax accumulation; enables true mbs > 1 batched-forward throughput when calibration data is uniformly full-length.
  • The real wins of this PR: (1) one shared calibration surface for M-LM + M-Bridge, prune + quantize; (2) trim+EOS semantically matches GPTSFTDataset(add_eos=True), closing a prior M-Bridge prune regression introduced by earlier pack=True variants; (3) MoE-friendly conservative defaults across all four call sites.

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Device parameter in dataset utilities now accepts device strings or device objects for greater flexibility.
    • New Megatron calibration forward loop for model calibration workflows.
  • Breaking Changes

    • Removed legacy HuggingFace-based calibration loop utility only used in this repo in the example script.
    • Calibration CLI arguments updated: --calib_mbs/--calib_gbs replaced with --calib_batch_size.
  • Documentation

    • Updated pruning examples to reflect new calibration utilities and arguments.

Review Change Stack

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 15, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 15, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR migrates calibration logic from HuggingFace-based utilities to a new Megatron-Core calibration forward-loop. The change adds get_megatron_calibration_forward_loop, removes legacy HF calibration code from mbridge.py, broadens the device parameter to accept strings, updates test fixtures to use real tokenizers, and rewrites prune examples and documentation to use the new approach.

Changes

Megatron Calibration Migration

Layer / File(s) Summary
Megatron calibration forward-loop module and plugin registration
modelopt/torch/utils/plugins/megatron_calibration.py, modelopt/torch/utils/plugins/__init__.py
New module exports get_megatron_calibration_forward_loop(tokenizer, ...) which constructs a calibration dataloader, materializes and sorts samples by unpadded length, and returns a forward-loop callable that processes batches with CP-rank slicing, padding detection, EOS placement, and logits-free megatron_prefill invocation. Module is registered via plugin import.
Remove HuggingFace calibration utilities from mbridge
modelopt/torch/utils/plugins/mbridge.py
Removes get_hf_mbridge_calibration_loop, _get_dataset_cfg, related imports, and trims __all__ to export only load_mbridge_model_from_hf.
Expand device parameter to accept strings
modelopt/torch/utils/dataset_utils.py
Updates get_dataset_dataloader device parameter type from torch.device | None to torch.device | str | None.
Update test tokenizer fixtures and configurations
tests/unit/torch/utils/test_dataset_utils.py, tests/_test_utils/torch/tokenizer/special_tokens_map.json, tests/_test_utils/torch/tokenizer/tokenizer_config.json
Replaces _FakeTokenizer with real tiny HuggingFace tokenizer from get_tiny_tokenizer() in pad_tokenizer fixture; adds pad_token configuration to tokenizer config and special_tokens_map JSON with proper syntax and EOS fallback.
Migrate prune_minitron.py to new calibration approach
examples/megatron_bridge/prune_minitron.py, examples/megatron_bridge/README.md
Updates imports to use get_megatron_calibration_forward_loop, replaces --calib_mbs/--calib_gbs CLI arguments with --calib_batch_size, rewrites forward-loop construction with tokenizer and dataset parameters, and adjusts inference batch-size defaults and help text. README example command updated to use --calib_batch_size 1.
Update pruning example documentation
examples/pruning/README.md
Updates Minitron "Common Setup" example to import get_megatron_calibration_forward_loop from modelopt.torch.utils.plugins.megatron_calibration and rewrites forward-loop example to pass tokenizer, dataset_name, num_samples, and seq_length instead of model/provider/hf-checkpoint parameters.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 62.50% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title directly and accurately summarizes the main objective: creating a shared Megatron calibration forward loop for pruning and quantization workflows.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns found. No torch.load(weights_only=False), numpy.load(allow_pickle=True), eval/exec, or # nosec. trust_remote_code is safe (configurable, defaults to False) in mbridge.py.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch kmorabia/fix-mcore-minitron-hybrid-and-qwen-te-spec

Comment @coderabbitai help to get the list of available commands and usage tips.

@kevalmorabia97
Copy link
Copy Markdown
Collaborator Author

/ok to test df7ab63

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 15, 2026

PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1501/

Built to branch gh-pages at 2026-05-21 15:36 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

kevalmorabia97 added a commit to kevalmorabia97/Megatron-LM that referenced this pull request May 15, 2026
Two WARs for modelopt <= 0.44 (fixed upstream in
NVIDIA/Model-Optimizer#1501):

- `prune.py`: after `import_mcore_gpt_from_hf` returns, walk the model
  and copy `model.layers.{i}.input_layernorm.weight` and
  `model.layers.{i}.post_attention_layernorm.weight` from HF into the
  fused `TELayerNormColumnParallelLinear.layer_norm_weight` parameters
  on `attention.linear_qkv` and `mlp.linear_fc1`. Without this the
  fused LayerNorm weights stay at random init for GPT-family models
  (Qwen3, Llama, ...) since modelopt 0.44's importer only loads
  `fused_norm` for Nemotron-H, leaving post-prune MMLU at chance.
  The WAR fails soft on missing HF keys, so it is a no-op on
  Nemotron-H (which uses `backbone.layers.{i}.norm.weight`).

- `mmlu.py`: load `modelopt.torch.utils.plugins.megatron_generate` via
  `importlib.import_module` to grab the submodule rather than the
  function the package re-exports under the same name. The previous
  `from ... import megatron_generate as _mg_plugin` form raised
  `AttributeError: 'function' object has no attribute
  'broadcast_from_last_pipeline_stage'` at import time.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@codecov
Copy link
Copy Markdown

codecov Bot commented May 15, 2026

Codecov Report

❌ Patch coverage is 83.33333% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 60.35%. Comparing base (a5bc6f8) to head (ff6f279).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
...delopt/torch/utils/plugins/megatron_calibration.py 82.00% 9 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (a5bc6f8) and HEAD (ff6f279). Click for more details.

HEAD has 6 uploads less than BASE
Flag BASE (a5bc6f8) HEAD (ff6f279)
unit 2 1
gpu 3 1
examples 10 7
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1501       +/-   ##
===========================================
- Coverage   76.79%   60.35%   -16.44%     
===========================================
  Files         474      476        +2     
  Lines       51560    52602     +1042     
===========================================
- Hits        39593    31750     -7843     
- Misses      11967    20852     +8885     
Flag Coverage Δ
examples 32.85% <83.33%> (-6.31%) ⬇️
gpu 15.75% <0.00%> (-44.59%) ⬇️
regression 15.22% <12.96%> (+0.08%) ⬆️
unit 52.61% <12.96%> (-0.03%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

kevalmorabia97 added a commit to kevalmorabia97/Megatron-LM that referenced this pull request May 15, 2026
`get_dataset_dataloader` tokenizes each calibration sample individually
with `padding=True, truncation=True, max_length=512`. For long-document
datasets like cnn_dailymail (typical article: ~700-1000 tokens), that
truncates most of each article and pads short ones — feeding the
importance estimator a heavily padded, contextually-impoverished batch.

Pack samples into uniform `calib_max_sequence_length` chunks from the
concatenated token stream (with `eos_token_id` as document separator)
the way Megatron-Bridge's calibration loop does. This exposes the model
to many more distinct contexts per `calib_size` samples and eliminates
padding-token contamination of activation statistics.

Measured impact on Qwen3-8B pruned to 30L/3584/11776 (5.99B params):
  before (trunc+pad):    MMLU 0.486
  after  (packed):       MMLU 0.544   (+5.8 pts, M-Bridge ref 0.563)

The proper upstream fix is to add a `pack` mode to
`get_dataset_dataloader` in modelopt
(NVIDIA/Model-Optimizer#1501); this inline
change makes prune.py work today against released modelopt 0.44.0.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 changed the title fix(prune): support HybridModel in mcore_minitron + Qwen3 fused-TE import fix(prune): mcore_minitron HybridModel + Qwen3 fused-TE import + calibration packing May 15, 2026
@kevalmorabia97
Copy link
Copy Markdown
Collaborator Author

/ok to test 20d3c5b

@kevalmorabia97 kevalmorabia97 marked this pull request as ready for review May 15, 2026 20:24
@kevalmorabia97 kevalmorabia97 requested review from a team as code owners May 15, 2026 20:24
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/fix-mcore-minitron-hybrid-and-qwen-te-spec branch from 20d3c5b to 8b537ce Compare May 15, 2026 20:26
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 59-60: The current enablement enables HybridModel defaults but
pattern updates and metric accounting are still gated by isinstance(...,
MambaModel), causing plain HybridModel instances to skip hybrid-pattern logic;
update the checks in the functions that perform hybrid pattern updates and
candidate metric accounting (the places currently using isinstance(obj,
MambaModel)) to detect HybridModel instead (e.g., isinstance(obj, HybridModel))
so plain HybridModel instances get the same pattern handling and metric updates,
and if any MambaModel-specific behavior is required keep a secondary
isinstance(obj, MambaModel) branch for those special cases.

In `@modelopt/torch/utils/dataset_utils.py`:
- Around line 492-526: The code can silently produce zero-length output when
token_stream is too short: after computing n_chunks from token_stream, check for
the edge cases and fail fast or warn; specifically, in the packing branch after
computing n_chunks (and before building input_ids/batch_encoded) add logic that
raises a clear exception if n_chunks == 0 (mentioning tokenizer.encode,
token_stream, total_chunks, and max_sample_length) and log/warn when n_chunks <
total_chunks to inform the caller they received fewer chunks than requested;
ensure the exception/warning uses the existing logging mechanism or raises a
ValueError so callers cannot proceed with empty tensors.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: d3ed2625-6222-473c-80cd-32fb4d6fbd4c

📥 Commits

Reviewing files that changed from the base of the PR and between a451a2b and 20d3c5b.

📒 Files selected for processing (9)
  • CHANGELOG.rst
  • modelopt/torch/export/plugins/mcore_qwen.py
  • modelopt/torch/export/plugins/megatron_importer.py
  • modelopt/torch/nas/plugins/megatron.py
  • modelopt/torch/prune/plugins/mcore_minitron.py
  • modelopt/torch/utils/dataset_utils.py
  • modelopt/torch/utils/plugins/megatron_generate.py
  • modelopt/torch/utils/plugins/megatron_mmlu.py
  • tests/unit/torch/utils/test_dataset_utils.py

Comment thread modelopt/torch/prune/plugins/mcore_minitron.py
Comment thread modelopt/torch/utils/dataset_utils.py Outdated
@kevalmorabia97
Copy link
Copy Markdown
Collaborator Author

/claude review

@claude
Copy link
Copy Markdown

claude Bot commented May 15, 2026

Claude review summary

Findings: CRITICAL: 1, IMPORTANT: 2, SUGGESTION: 2

Most impactful

  1. Depth-pruning HybridModel silently keeps stale hybrid_override_pattern (mcore_minitron.py:372, :689). The PR registers HybridModel as a supported NAS root so convert_to_dynamic builds a real search space — but the depth-pruning post-processing still gates on isinstance(self.model, MambaModel). Per the PR description, newer Megatron-LM instantiates Nemotron-H et al. as plain HybridModel (a sibling/parent of MambaModel), so those checks miss the model and the saved checkpoint's layer pattern won't agree with the pruned config.num_layers. Width-only prune workflows (the path the PR description's Nemotron-H validation seems to exercise) are unaffected; depth pruning is the at-risk path.

  2. Qwen2.5 (and other GPT-family) import paths still hit the chance-accuracy bug. The fused-TE-spec rule keys (fused_input_layernorm / fused_pre_mlp_layernorm) are added only to qwen3_causal_lm_import; qwen25_causal_lm_import, mcore_llama, mcore_deepseek, mcore_gptoss still fall back to the old fused_norm-only path. HF naming is identical, so the same MMLU-at-chance failure is one config flip away on those archs.

  3. pack=True silently produces fewer chunks than requested when raw-text supply runs short of 2 × num_samples. A warn_rank_0 (or grow-the-oversample-factor loop) keeps this from being an undebuggable downstream ablation surprise.

Strengths

  • The fused-norm rule-key fallback (fused_input_layernormfused_norm) is the right shape — it's strictly additive for Nemotron-H and explicit-per-norm for GPT-family models, no behavior change for existing import paths.
  • The _DynamicTEQKVLayerNormColumnParallelLinear.in_features registration is consistent with the existing _DynamicTEParallelLinear pattern (both bind in_features to mod.input_size).
  • The .contiguous() fix on the megatron_prefill logits slice has a clear root-cause comment (SP padding + broadcast contiguity assert) and is a one-line, reversible change.
  • CHANGELOG entries are placed under correct sections in 0.45 with accurate scope.

Risk assessment

Medium. The Qwen3 / Nemotron-H prune paths the PR explicitly validated end-to-end look solid, and the calibration-packing primitive is well-isolated behind an opt-in flag. The pre-existing isinstance(model, MambaModel) checks in depth-pruning are the main concern — they should be widened to cover plain HybridModel in the same PR that promotes HybridModel to a first-class supported type, otherwise the bug "fixed" for convert_to_dynamic re-emerges in the depth-prune codepath.

Comment thread modelopt/torch/prune/plugins/mcore_minitron.py Outdated
Comment thread modelopt/torch/export/plugins/mcore_qwen.py
Comment thread modelopt/torch/utils/dataset_utils.py Outdated
Comment thread tests/unit/torch/utils/test_dataset_utils.py
Comment thread modelopt/torch/utils/dataset_utils.py Outdated
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/fix-mcore-minitron-hybrid-and-qwen-te-spec branch from 8b537ce to b70423f Compare May 16, 2026 11:10
@kevalmorabia97 kevalmorabia97 requested a review from a team as a code owner May 19, 2026 19:56
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/fix-mcore-minitron-hybrid-and-qwen-te-spec branch from d174a08 to ca3af81 Compare May 19, 2026 22:27
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/fix-mcore-minitron-hybrid-and-qwen-te-spec branch 2 times, most recently from c6e4e98 to 5e1b424 Compare May 20, 2026 16:36
@kevalmorabia97 kevalmorabia97 changed the title feat(utils): pack=True calibration mode for get_dataset_dataloader Create shared Megatron calibration forward loop for prune / quantize May 20, 2026
@kevalmorabia97
Copy link
Copy Markdown
Collaborator Author

@CodeRabbit review

@kevalmorabia97
Copy link
Copy Markdown
Collaborator Author

/claude review

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 20, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review summary

Findings: CRITICAL: 1, IMPORTANT: 1, SUGGESTION: 3.

The shared calibration loop is a clean consolidation and the experimental data backs the design. One blocker, one back-compat side effect, and a few smaller notes — all in the new megatron_calibration.py.

Highest-impact

  • CRITICAL Algorithm — padding-direction assumption (megatron_calibration.py:131-144). The per-row trim slices ids[b, :real_len] and the no-padding branch overwrites ids[:, -1], both of which assume right-padded sequences. But get_dataset_dataloader (the function this loop calls) explicitly recommends padding_side="left" and warns when it isn't (dataset_utils.py:593-596). With left-padding, the trim selects padding tokens instead of real tokens — re-introducing the exact contamination the PR is trying to eliminate. Your Qwen3-8B numbers move in the right direction because that tokenizer happens to default to right-padding, but a user who follows the dataloader's documented recommendation silently regresses. Capture padding_side once and branch the slice (ids[b, -real_len:] for left-padding); also worth a regression test row whose attention_mask has zeros to lock the behavior in.

  • IMPORTANT Compatibility — caller-tokenizer mutation (megatron_calibration.py:87-88). Setting tokenizer.pad_token = tokenizer.eos_token mutates the caller's tokenizer object before the dataloader's deepcopy. The unwrapped tokenizer in prune_minitron.py is reused downstream (MMLU eval, export). Tokenizers whose pad_token was None will silently start emitting EOS as their pad token in every later code path that touches them. Local-copy the tokenizer before the mutation, or set the field on the dataloader-internal deepcopy only.

SUGGESTIONs (non-blocking)

  • Docstring claim that the loop matches GPTSFTDataset(add_eos=True) is slightly off — that flag appends EOS, the code overwrites the last real token. Reword or actually append when under seq_length.
  • bool((mask == 0).any().item()) and mask[b].sum().item() cause a CPU-GPU sync per batch / per row. Precompute real-lengths once on device into a CPU list at builder time to drop ~2*num_batches syncs.
  • The CP-rank slice inside _forward_loop is dead code in the tested configs (CP=1) and would actually break under CP>1 because different CP ranks would compute different per-row real_len and call megatron_prefill (a collective) with shape-mismatched inputs. Either run the trim before the CP slice, or drop the CP-shard call until CP>1 is validated.

Backward compatibility

  • Public API change: removing get_hf_mbridge_calibration_loop from modelopt.torch.utils.plugins.mbridge is a breaking change for any out-of-tree caller. Acceptable since the in-tree call sites are migrated and there's no schema/checkpoint impact, but flagging it for the maintainer's awareness.
  • examples/megatron_bridge/prune_minitron.py renames --calib_mbs / --calib_gbs to --calib_batch_size. Existing user scripts/CI that invoke this example with the old flags will fail at parse time. README is updated in the same PR — fine.

Risk

Low-to-moderate. The CRITICAL is a correctness contradiction with the dataloader the loop relies on; once that's reconciled and the tokenizer mutation is confined to a local copy, this is a clear quality win over the previous bespoke loops.

Comment thread modelopt/torch/utils/plugins/megatron_calibration.py Outdated
Comment thread modelopt/torch/utils/plugins/megatron_calibration.py
Comment thread modelopt/torch/utils/plugins/megatron_calibration.py Outdated
Comment thread modelopt/torch/utils/plugins/megatron_calibration.py
Comment thread modelopt/torch/utils/plugins/megatron_calibration.py Outdated
Replaces the bespoke calibration loops in M-LM and M-Bridge prune/quantize
example scripts with a single shared util,
``modelopt.torch.utils.plugins.megatron_calibration.get_megatron_calibration_forward_loop``.

The shared loop emits one sample per row (via ``get_dataset_dataloader``),
trims each row to its real length using the dataloader's attention mask, and
forces EOS at the trimmed last position before forwarding via
``megatron_prefill(skip_return_logits=True)``. Matches legacy
``GPTSFTDataset(add_eos=True)`` semantics exactly. Samples are sorted by real
length descending so front batches are mostly full-length (true batched
forward); back batches that contain padding fall through to per-row forward
to keep padding-token activations out of the calibration hook stream.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/fix-mcore-minitron-hybrid-and-qwen-te-spec branch from 5e1b424 to d1597cf Compare May 20, 2026 17:15
@kevalmorabia97
Copy link
Copy Markdown
Collaborator Author

/claude review

Addressed all 5 Claude review comments in megatron_calibration.py (commit d1597cf83):

  1. [CRITICAL Algorithm] right-padding assumption — replaced ids[b, :real_len] (which picked padding for left-padded tokenizers) with boolean-mask gather ids[b][mask[b].bool()]. Now works for both left- and right-padded tokenizers regardless of the dataloader's recommendation.

  2. [IMPORTANT Compatibility] mutates caller's tokenizertokenizer = copy.deepcopy(tokenizer) before setting pad_token. The caller's tokenizer (reused downstream for MMLU / export) is no longer silently modified.

  3. [SUGGESTION] CPU-GPU sync per batch — pre-compute per-row lengths once on CPU at builder time (lengths_cpu = cat_masks.sum(dim=-1).cpu()); the hot loop's has_padding check reads from a CPU tensor (no .item() sync per iteration). Per-row mask[b].sum() sync also gone — boolean gather doesn't need real_len directly.

  4. [SUGGESTION] docstring "matches add_eos=True" — docstring now clarifies that EOS overwrites the row's last real token (not appends), and explicitly notes the under-cap row trade-off (loses one content token for a consistent end-of-document marker).

  5. [SUGGESTION] CP > 1 ordering — moved get_batch_on_this_cp_rank to after per-row extraction (was before). Each row's full real-token sequence is CP-sliced just before forward, so all CP ranks agree on row length. CP=1 (what was tested) is unaffected.

Validation: re-ran M-LM prune on Qwen3-8B (nemotron-post-training-dataset-v2, seq=4096, calib=512, bs=16) post-fix → MMLU 0.584, vs 0.587 before fixes. Δ -0.3 pts, well within MMLU 5% noise floor (~±0.6 pt stderr). All fixes are quality-neutral (the boolean gather is numerically identical to [:real_len] slicing for right-padded tokenizers, which is what we tested).

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review summary

Findings: CRITICAL: 0, IMPORTANT: 2, SUGGESTION: 2

The refactor is well-motivated: the unified loop is simpler than the old M-Bridge evaluate_and_print_results pipeline, the per-row trim + EOS-at-end matches GPTSFTDataset(add_eos=True), and the sort-by-length / per-batch padding-fast-path is a clean way to get batched throughput without polluting calibration stats with padding. The MMLU numbers in the PR description back this up.

Most impactful findings

  1. Public API removal without deprecationget_hf_mbridge_calibration_loop is dropped from modelopt.torch.utils.plugins.mbridge.__all__ outright. Out-of-tree callers that pinned a previous example version will hit ImportError on upgrade with no migration hint. A one-release deprecation shim is cheap.

  2. inference_batch_size defaulting to calib_batch_size silently changes prune_target_memory_mb math. Previously --calib_mbs was hard-pinned to 1; now --calib_batch_size is unrestricted and the PR description recommends 16 for throughput. A user who follows that recommendation without also setting --inference_batch_size gets a 16× larger KV-cache term in the memory constraint — i.e., a different prune target — silently. Either default inference_batch_size to 1 unconditionally, or warn loudly when this divergence is implied.

Non-blocking

  1. CP > 1 is not really supported by the get_batch_on_this_cp_rank + megatron_prefill combination (megatron_prefill rebuilds attention_mask / position_ids from the locally-sliced length). Fine while the validated configs are CP=1, worth either an assert or removing the CP wrapper to avoid future foot-gun.
  2. Minor docstring nit: "no logits compute" overstates what skip_return_logits=True does.

Risk assessment

Algorithmic changes are validated by the MMLU sweeps in the PR description and the per-row trim / EOS / no-padding-in-calibration logic looks right. The risk surface is the user-facing surface: removed public symbol and the silent-default-shift on the prune memory target. Both are easy to address.

Comment thread examples/megatron_bridge/prune_minitron.py
Comment thread modelopt/torch/utils/plugins/megatron_calibration.py Outdated
Comment thread modelopt/torch/utils/plugins/megatron_calibration.py
- Assert CP=1 in the calibration forward loop (megatron_prefill builds
  causal mask + position_ids over the local input tensor length, which
  would silently produce wrong activations under CP>1). Calibration
  sequences are short enough that CP doesn't help anyway.
- Drop the get_batch_on_this_cp_rank call (was a no-op under CP=1 and
  broken under CP>1 — the per-row branch would have produced rank-local
  trimmed lengths, violating the collective-shape assumption).
- Docstring: "logits-free prefill" → "skips returning logits / loss
  compute; the LM head still runs and activation hooks still fire on
  every layer" (matches what skip_return_logits=True actually does).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97
Copy link
Copy Markdown
Collaborator Author

/claude review

Context for this round:

  • CP > 1 in calibration loop (commit 3426d728): added explicit assert mpu.get_context_parallel_world_size() == 1 in get_megatron_calibration_forward_loop._forward_loop and dropped the get_batch_on_this_cp_rank call. megatron_prefill builds its causal mask + position_ids over the local input tensor length, so splitting a calibration sequence across CP ranks would silently produce wrong activations. Calibration sequences are short — CP doesn't help — so failing loud is the right call.
  • Docstring accuracy (same commit): "logits-free prefill / no logits compute" reworded to "skips returning logits / loss compute; the LM head still runs and activation hooks still fire on every layer", matching what skip_return_logits=True actually does.
  • get_hf_mbridge_calibration_loop removal: this function was the legacy M-Bridge SFT-pipeline calibration loop. We deleted it from modelopt/torch/utils/plugins/mbridge.py because its only in-tree caller (examples/megatron_bridge/prune_minitron.py) is migrated to the shared loop in the same PR. Both caller and callee live in this repo — no external API contract to preserve.
  • inference_batch_size not pinned to 1: in prune_minitron.py, --inference_batch_size defaults to args.calib_batch_size (which itself defaults to 1) but a user can override it for memory-targeted NAS pruning. Leaving it un-pinned to 1 is desired — the inference batch size used for KV-cache memory sizing should be free to differ from the calibration micro-batch size.

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude re-review

Findings (this pass): CRITICAL: 0, IMPORTANT: 2 (both carried over from prior review, still unresolved in code), SUGGESTION: 0.

Re-verification of prior findings

The single new commit (Address Claude review: CP-assert + clarify prefill docstring) addressed the two SUGGESTION items from the 2026-05-20 review:

  • ✅ CP > 1 footgun — runtime RuntimeError now raised inside _forward_loop when context_parallel_world_size != 1.
  • ✅ Docstring overstatement — module docstring now accurately describes skip_return_logits=True as "skips returning logits / loss compute … the LM head still runs and activation hooks still fire."

The two IMPORTANT items are not addressed:

  1. [IMPORTANT Compatibility] get_hf_mbridge_calibration_loop is removed from modelopt/torch/utils/plugins/mbridge.py:__all__ outright with the implementation deleted. Out-of-tree callers pinned to a previous release will hit ImportError on upgrade with no migration hint. A one-release deprecation shim (re-export plus DeprecationWarning pointing to get_megatron_calibration_forward_loop) would be cheap insurance.

  2. [IMPORTANT Compatibility] examples/megatron_bridge/prune_minitron.py:155-164 defaults --inference_batch_size to --calib_batch_size. Previously --calib_mbs was hard-pinned to choices=[1], so inference_batch_size always fell back to 1 — users could safely bump --calib_gbs for throughput without affecting the prune memory target. Now, the PR description recommends --calib_batch_size 16 (and the experimental table in the description used calib_bs 16 for the production runs), but anyone following that recommendation without also setting --inference_batch_size 1 will silently get a 16× larger KV-cache term in --prune_target_memory_mb — a different prune target than they intended. Default inference_batch_size=1 unconditionally, or warn when divergence is implied.

Algorithm / mode-state re-trace

End-to-end trace of the new get_megatron_calibration_forward_loop against current state:

  • Padding direction — boolean-mask gather (ids[b][mask[b].bool()]) correctly handles both left- and right-padded tokenizers; the prior padding-direction concern is fully addressed.
  • Tokenizer mutation — local copy.deepcopy(tokenizer) before tokenizer.pad_token = tokenizer.eos_token correctly isolates the mutation from the caller's tokenizer. get_dataset_dataloader does its own deepcopy, so the caller is double-protected.
  • CPU-GPU synclengths_cpu precomputed once on CPU; no per-batch sync inside the forward hot loop. ✓
  • Sort-by-length / per-batch padding-fast-path — for batches with no padding, batched forward; for any batch containing padding, falls through to per-row forward. Calibration statistics (amax / channel-importance) are order-invariant aggregates, so the re-ordering is bit-identical to un-sorted. ✓
  • PP coordination — all PP ranks see the same dataloader output (deterministic get_dataset_samples + shuffle=False) and call megatron_prefill with matching shapes per iteration. ✓
  • EOS-at-row-end semantics — overwrites the row's last real token with EOS for both the per-row and batched paths. Docstring is honest about the under-cap-row content-token loss trade-off. ✓
  • Loop invariants — partial-tail batches (n % batch_size != 0) handled correctly via for b in range(ids.shape[0]). Zero-real-token rows skipped via row.shape[1] < 1 continue.

Backward compatibility

  • CLI rename --calib_mbs/--calib_gbs--calib_batch_size (item 2 above).
  • Public API removal of get_hf_mbridge_calibration_loop (item 1 above).
  • No modelopt_state schema changes.

Risk

Low-to-moderate. The algorithm itself is sound and matches the experimental data in the PR description. Risk surface is the user-facing surface: a public-symbol removal with no shim, and a CLI default coupling that quietly changes the prune target when the user follows the PR's own throughput recommendation. Both are easy to address.

@kevalmorabia97 kevalmorabia97 requested review from yueshen2016 and removed request for AAnoosheh May 20, 2026 18:39
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants